Camera calibration

In this notebook, we solve a camera calibration problem: estimating intrinsic and extrinsic parameters from checkerboard observations.

Inputs: Images of checkerboard pattern (OpenCV sample data)
Outputs: Camera intrinsics (focal lengths, principal point, distortion coefficients) and one extrinsic camera pose per image

Features used:

  • Var subclass for camera intrinsics

  • SE3Var for camera extrinsic poses

  • @jaxls.Cost.factory for reprojection error

  • OpenCV for chessboard corner detection

Hide code cell source

import sys
from loguru import logger

# Drop loguru's existing handlers and log to stdout with a compact
# "LEVEL | message" format.
logger.remove()
logger.add(sys.stdout, format="<level>{level: <8}</level> | {message}");
import urllib.request
from pathlib import Path

import cv2
import jax
import jax.numpy as jnp
import jaxls
import jaxlie
import numpy as np
from scipy import ndimage

Download OpenCV sample images

Download the calibration images from the OpenCV repository:

def download_calibration_images(
    cache_dir: Path = Path("/tmp/opencv_calib"),
) -> list[Path]:
    """Download OpenCV sample calibration images.

    Images already present in ``cache_dir`` are reused without touching the
    network. A missing image is first fetched to a temporary ``.part`` file
    and only renamed to its final name once the transfer has completed, so an
    interrupted download is never mistaken for a cached image on a later run.

    Args:
        cache_dir: Directory to cache downloaded images

    Returns:
        List of paths to downloaded image files
    """
    cache_dir.mkdir(parents=True, exist_ok=True)
    base_url = "https://raw.githubusercontent.com/opencv/opencv/master/samples/data"

    # Note: left10.jpg doesn't exist in the OpenCV repo.
    image_indices = [1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 12, 13, 14]

    image_paths = []
    for i in image_indices:
        filename = f"left{i:02d}.jpg"
        local_path = cache_dir / filename

        if not local_path.exists():
            url = f"{base_url}/{filename}"
            logger.info(f"Downloading {filename}...")
            partial_path = cache_dir / f"{filename}.part"
            try:
                urllib.request.urlretrieve(url, partial_path)
                # Publish the file under its real name only after a full transfer.
                partial_path.replace(local_path)
            finally:
                # Remove leftovers from a failed transfer; no-op after a successful rename.
                partial_path.unlink(missing_ok=True)

        image_paths.append(local_path)

    return image_paths


# Fetch the sample images (or reuse files already in the cache directory)
# and report how many paths were returned.
image_paths = download_calibration_images()
print(f"Downloaded {len(image_paths)} calibration images")
INFO     | Downloading left01.jpg...
INFO     | Downloading left02.jpg...
INFO     | Downloading left03.jpg...
INFO     | Downloading left04.jpg...
INFO     | Downloading left05.jpg...
INFO     | Downloading left06.jpg...
INFO     | Downloading left07.jpg...
INFO     | Downloading left08.jpg...
INFO     | Downloading left09.jpg...
INFO     | Downloading left11.jpg...
INFO     | Downloading left12.jpg...
INFO     | Downloading left13.jpg...
INFO     | Downloading left14.jpg...
Downloaded 13 calibration images

Detect chessboard corners

Use OpenCV to detect the inner corners of the 9x6 chessboard pattern:

# Chessboard parameters: 9x6 inner corners.
board_cols, board_rows = 9, 6
square_size = 0.025  # 25mm squares (approximate); value is in meters

# 3D checkerboard points (on Z=0 plane)
# The grid is laid out with the column (X) index varying fastest, giving one
# (x, y, 0) row per inner corner, scaled by the square size.
board_points_3d = np.zeros((board_rows * board_cols, 3), np.float32)
board_points_3d[:, :2] = (
    np.mgrid[0:board_cols, 0:board_rows].T.reshape(-1, 2) * square_size
)
board_points_3d = jnp.array(board_points_3d)

print(f"Chessboard: {board_cols}x{board_rows} = {len(board_points_3d)} corners")
# NOTE(review): this multiplies the inner-corner counts by the square size; the
# corner grid itself spans (board_cols - 1) x (board_rows - 1) squares, so the
# printed extent is one square larger than the grid in each direction.
print(
    f"Board size: {board_cols * square_size * 1000:.0f}mm x {board_rows * square_size * 1000:.0f}mm"
)
Chessboard: 9x6 = 54 corners
Board size: 225mm x 150mm
# Detect corners in all images.
# observations_2d[k] holds the refined corner locations of the k-th image in
# which the board was found; valid_image_indices[k] is that image's position
# in image_paths.
observations_2d: list[jax.Array] = []
valid_image_indices: list[int] = []
image_size = None

# Stop sub-pixel refinement after 30 iterations or once the update is < 0.001.
criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 30, 0.001)

for i, path in enumerate(image_paths):
    img = cv2.imread(str(path))
    if img is None:
        # cv2.imread returns None (instead of raising) for a missing or
        # undecodable file; skip it rather than failing on img.shape below.
        print(f"  Image {i + 1:2d}: ✗ Could not read {path}")
        continue
    if image_size is None:
        image_size = (img.shape[1], img.shape[0])  # (width, height)

    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    ret, corners = cv2.findChessboardCorners(gray, (board_cols, board_rows), None)

    if ret:
        # Refine corner positions.
        corners_refined = cv2.cornerSubPix(gray, corners, (11, 11), (-1, -1), criteria)
        observations_2d.append(jnp.array(corners_refined.squeeze()))
        valid_image_indices.append(i)
        print(f"  Image {i + 1:2d}: ✓ Found {len(corners)} corners")
    else:
        print(f"  Image {i + 1:2d}: ✗ Chessboard not found")

num_views = len(observations_2d)
if num_views == 0:
    # Nothing to calibrate from; fail here with a clear message instead of
    # further down on image_size or an empty optimization problem.
    raise RuntimeError(
        f"No chessboard was detected in any of the {len(image_paths)} images"
    )
print(f"\nSuccessfully detected corners in {num_views}/{len(image_paths)} images")
print(f"Image size: {image_size[0]}x{image_size[1]}")
  Image  1: ✓ Found 54 corners
  Image  2: ✓ Found 54 corners
  Image  3: ✓ Found 54 corners
  Image  4: ✓ Found 54 corners
  Image  5: ✓ Found 54 corners
  Image  6: ✓ Found 54 corners
  Image  7: ✓ Found 54 corners
  Image  8: ✓ Found 54 corners
  Image  9: ✓ Found 54 corners
  Image 10: ✓ Found 54 corners
  Image 11: ✓ Found 54 corners
  Image 12: ✓ Found 54 corners
  Image 13: ✓ Found 54 corners

Successfully detected corners in 13/13 images
Image size: 640x480

Hide code cell source

import plotly.graph_objects as go
from plotly.subplots import make_subplots
from IPython.display import HTML

# Show sample input images with detected corners.
# The indices below select entries of observations_2d / valid_image_indices
# (i.e. positions among the images where detection succeeded).
sample_indices = [0, 3, 6]  # Show 3 sample images

# One subplot per sample; titles use the 1-based position of the image in image_paths.
fig_samples = make_subplots(
    rows=1,
    cols=len(sample_indices),
    subplot_titles=[f"Image {valid_image_indices[i] + 1}" for i in sample_indices],
)

for col, idx in enumerate(sample_indices):
    # Reload the image from disk and convert OpenCV's BGR order to RGB for plotting.
    img = cv2.imread(str(image_paths[valid_image_indices[idx]]))
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # Draw detected corners.
    # drawChessboardCorners draws onto img_rgb in place; corners are reshaped to (N, 1, 2).
    corners = np.array(observations_2d[idx])
    cv2.drawChessboardCorners(
        img_rgb, (board_cols, board_rows), corners.reshape(-1, 1, 2), True
    )

    # Plotly subplot columns are 1-based.
    fig_samples.add_trace(
        go.Image(z=img_rgb),
        row=1,
        col=col + 1,
    )

fig_samples.update_xaxes(showticklabels=False)
fig_samples.update_yaxes(showticklabels=False)
fig_samples.update_layout(
    height=280,
    margin=dict(t=40, b=20, l=20, r=20),
)
# Embed the figure as an HTML fragment, loading plotly.js from the CDN.
HTML(fig_samples.to_html(full_html=False, include_plotlyjs="cdn"))

Camera model

We use the Brown-Conrady distortion model (same as OpenCV):

\[x' = x_n (1 + k_1 r^2 + k_2 r^4) + 2 p_1 x_n y_n + p_2 (r^2 + 2 x_n^2)\]
\[y' = y_n (1 + k_1 r^2 + k_2 r^4) + p_1 (r^2 + 2 y_n^2) + 2 p_2 x_n y_n\]
\[u = f_x \cdot x' + c_x, \quad v = f_y \cdot y' + c_y\]

where \((x_n, y_n)\) are normalized coordinates, \(r^2 = x_n^2 + y_n^2\), \((k_1, k_2)\) are radial distortion coefficients, and \((p_1, p_2)\) are tangential distortion coefficients.

class IntrinsicsVar(
    jaxls.Var[jax.Array],
    default_factory=lambda: jnp.array(
        [
            500.0,  # fx
            500.0,  # fy
            320.0,  # cx
            240.0,  # cy
            0.0,  # k1
            0.0,  # k2
            0.0,  # p1
            0.0,  # p2
        ]
    ),
):
    """Optimization variable holding the camera intrinsics.

    The value is a length-8 array laid out as [fx, fy, cx, cy, k1, k2, p1, p2].
    """


@jax.jit
def project_brown_conrady(
    points_camera: jax.Array,  # (N, 3) points in camera frame
    intrinsics: jax.Array,  # [fx, fy, cx, cy, k1, k2, p1, p2]
) -> jax.Array:
    """Map camera-frame 3D points to pixel coordinates with lens distortion.

    Points are perspective-divided, distorted with two radial (k1, k2) and two
    tangential (p1, p2) Brown-Conrady terms, then scaled by the focal lengths
    and shifted by the principal point.

    Args:
        points_camera: 3D points in camera frame (N, 3)
        intrinsics: Camera intrinsics [fx, fy, cx, cy, k1, k2, p1, p2]

    Returns:
        2D projected points (N, 2)
    """
    fx, fy, cx, cy, k1, k2, p1, p2 = intrinsics

    # Clamp depth from below so the perspective divide never hits zero
    # (any depth under 1e-6, including negative ones, becomes 1e-6).
    depth = jnp.maximum(points_camera[..., 2], 1e-6)
    xn = points_camera[..., 0] / depth
    yn = points_camera[..., 1] / depth

    # Radial gain as a polynomial in the squared radius.
    rr = xn**2 + yn**2
    radial_gain = 1.0 + k1 * rr + k2 * rr**2

    # Tangential terms, kept separate and summed in the model's order.
    x_cross = 2 * p1 * xn * yn
    x_sym = p2 * (rr + 2 * xn**2)
    y_sym = p1 * (rr + 2 * yn**2)
    y_cross = 2 * p2 * xn * yn
    x_dist = xn * radial_gain + x_cross + x_sym
    y_dist = yn * radial_gain + y_sym + y_cross

    # Distorted normalized coordinates -> pixels.
    pixels = (fx * x_dist + cx, fy * y_dist + cy)
    return jnp.stack(pixels, axis=-1)

Problem construction

We optimize camera intrinsics and all extrinsic poses jointly using reprojection error:

# Variables.
# A single intrinsics variable is shared by every view; each view in which the
# board was detected gets its own SE3 pose variable, with ids 0..num_views-1.
intrinsics_var = IntrinsicsVar(id=0)
pose_vars = [jaxls.SE3Var(id=i) for i in range(num_views)]
@jaxls.Cost.factory
def reprojection_cost(
    vals: jaxls.VarValues,
    intrinsics_var: IntrinsicsVar,
    pose_var: jaxls.SE3Var,
    points_3d: jax.Array,  # (N, 3) batch of 3D points
    observed_2d: jax.Array,  # (N, 2) batch of observed 2D points
) -> jax.Array:
    """Residual between projected and observed corners for one view.

    The view's pose is applied to every 3D point, the results are projected
    with the current intrinsics, and the pixel differences to the observations
    are returned as one flat vector (two entries per point).
    """
    view_pose = vals[pose_var]
    camera_params = vals[intrinsics_var]

    # Apply the pose to each 3D point, then project into the image.
    in_camera = jax.vmap(view_pose.apply)(points_3d)
    residual = project_brown_conrady(in_camera, camera_params) - observed_2d

    return jnp.ravel(residual)
# Build costs using batched construction - one cost per view.
# Each cost covers every board corner of its view at once: it ties the shared
# intrinsics variable and that view's pose variable to the view's observations.
costs: list[jaxls.Cost] = [
    reprojection_cost(
        intrinsics_var,
        pose_vars[view_idx],
        board_points_3d,  # All 3D points
        observations_2d[view_idx],  # All 2D observations for this view
    )
    for view_idx in range(num_views)
]

print(f"Created {len(costs)} batched reprojection costs ({num_views} views)")
Created 13 batched reprojection costs (13 views)
# Initialize intrinsics with reasonable guesses.
# Focal length ~ half the image width (fx and fy both use the width),
# principal point ~ image center, all distortion coefficients zero.
init_fx = float(image_size[0]) / 2
init_fy = float(image_size[0]) / 2
init_cx = float(image_size[0]) / 2
init_cy = float(image_size[1]) / 2
# Layout matches IntrinsicsVar: [fx, fy, cx, cy, k1, k2, p1, p2].
init_intrinsics = jnp.array([init_fx, init_fy, init_cx, init_cy, 0.0, 0.0, 0.0, 0.0])

print(
    f"Initial intrinsics: fx={init_fx:.0f}, fy={init_fy:.0f}, cx={init_cx:.0f}, cy={init_cy:.0f}"
)
Initial intrinsics: fx=320, fy=320, cx=320, cy=240
def estimate_initial_pose(
    observed_corners: jax.Array, intrinsics: jax.Array
) -> jaxlie.SE3:
    """Estimate initial pose using OpenCV's solvePnP.

    The 3D model points are the module-level ``board_points_3d``; the returned
    transform is built from solvePnP's rotation vector and translation.

    Args:
        observed_corners: Detected 2D corner positions (N, 2)
        intrinsics: Camera intrinsics [fx, fy, cx, cy, k1, k2, p1, p2]

    Returns:
        Estimated camera pose as SE3

    Raises:
        RuntimeError: If solvePnP reports that it could not find a pose.
    """
    fx, fy, cx, cy, k1, k2, p1, p2 = intrinsics
    camera_matrix = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float64)
    # OpenCV's 4-coefficient distortion vector is ordered (k1, k2, p1, p2).
    dist_coeffs = np.array([k1, k2, p1, p2], dtype=np.float64)

    success, rvec, tvec = cv2.solvePnP(
        np.array(board_points_3d),
        np.array(observed_corners),
        camera_matrix,
        dist_coeffs,
    )
    if not success:
        # Don't seed the optimizer with an rvec/tvec that solvePnP flagged as invalid.
        raise RuntimeError("cv2.solvePnP failed to estimate an initial pose")

    # Rotation vector (axis-angle) -> 3x3 rotation matrix.
    R, _ = cv2.Rodrigues(rvec)
    rotation = jaxlie.SO3.from_matrix(jnp.array(R))
    translation = jnp.array(tvec.squeeze())

    return jaxlie.SE3.from_rotation_and_translation(rotation, translation)


# Estimate initial poses.
# One PnP solve per detected view, all using the same rough initial intrinsics.
init_poses = [estimate_initial_pose(obs, init_intrinsics) for obs in observations_2d]
print(f"Estimated {len(init_poses)} initial poses using PnP")
Estimated 13 initial poses using PnP
# Create initial values.
# Pair every variable with its starting value: the rough intrinsics guess and
# the per-view PnP poses.
initial_vals = jaxls.VarValues.make(
    [intrinsics_var.with_value(init_intrinsics)]
    + [pose_vars[i].with_value(init_poses[i]) for i in range(num_views)]
)

# Create and analyze problem.
# The variable list contains the shared intrinsics followed by all view poses.
problem = jaxls.LeastSquaresProblem(costs, [intrinsics_var] + pose_vars).analyze()
INFO     | Building optimization problem with 13 terms and 14 variables: 13 costs, 0 eq_zero, 0 leq_zero, 0 geq_zero
INFO     | Vectorizing group with 13 costs, 2 variables each: reprojection_cost

Solving

# Solve starting from the initial values; `solution` is indexed by variable
# (e.g. solution[intrinsics_var]) in the cells that follow.
solution = problem.solve(initial_vals)
INFO     |  step #0: cost=11776.8213 lambd=0.0005 inexact_tol=1.0e-02
INFO     |      - reprojection_cost(13): 11776.82129 (avg 8.38805)
INFO     |      accepted=True ATb_norm=3.08e+04 cost_prev=11776.8213 cost_new=118103.0781
INFO     |  step #1: cost=118103.0781 lambd=0.0003 inexact_tol=1.0e-02
INFO     |      - reprojection_cost(13): 118103.07812 (avg 84.11900)
INFO     |      accepted=True ATb_norm=4.23e+06 cost_prev=118103.0781 cost_new=630.2039
INFO     |  step #2: cost=630.2039 lambd=0.0001 inexact_tol=1.0e-02
INFO     |      - reprojection_cost(13): 630.20386 (avg 0.44886)
INFO     |  step #3: cost=630.2039 lambd=0.0003 inexact_tol=1.0e-02
INFO     |      - reprojection_cost(13): 630.20386 (avg 0.44886)
INFO     |  step #4: cost=630.2039 lambd=0.0005 inexact_tol=1.0e-02
INFO     |      - reprojection_cost(13): 630.20386 (avg 0.44886)
INFO     |  step #5: cost=630.2039 lambd=0.0010 inexact_tol=1.0e-02
INFO     |      - reprojection_cost(13): 630.20386 (avg 0.44886)
INFO     |  step #6: cost=630.2039 lambd=0.0020 inexact_tol=1.0e-02
INFO     |      - reprojection_cost(13): 630.20386 (avg 0.44886)
INFO     |  step #7: cost=630.2039 lambd=0.0040 inexact_tol=1.0e-02
INFO     |      - reprojection_cost(13): 630.20386 (avg 0.44886)
INFO     |  step #8: cost=630.2039 lambd=0.0080 inexact_tol=1.0e-02
INFO     |      - reprojection_cost(13): 630.20386 (avg 0.44886)
INFO     |  step #9: cost=630.2039 lambd=0.0160 inexact_tol=1.0e-02
INFO     |      - reprojection_cost(13): 630.20386 (avg 0.44886)
INFO     |      accepted=True ATb_norm=9.59e+04 cost_prev=630.2039 cost_new=236.5836
INFO     |  step #10: cost=236.5836 lambd=0.0080 inexact_tol=4.6e-04
INFO     |      - reprojection_cost(13): 236.58362 (avg 0.16851)
INFO     |      accepted=True ATb_norm=1.02e+05 cost_prev=236.5836 cost_new=117.5282
INFO     |  step #11: cost=117.5282 lambd=0.0040 inexact_tol=4.6e-04
INFO     |      - reprojection_cost(13): 117.52823 (avg 0.08371)
INFO     |  step #12: cost=117.5282 lambd=0.0080 inexact_tol=4.6e-04
INFO     |      - reprojection_cost(13): 117.52823 (avg 0.08371)
INFO     |  step #13: cost=117.5282 lambd=0.0160 inexact_tol=4.6e-04
INFO     |      - reprojection_cost(13): 117.52823 (avg 0.08371)
INFO     |  step #14: cost=117.5282 lambd=0.0320 inexact_tol=4.6e-04
INFO     |      - reprojection_cost(13): 117.52823 (avg 0.08371)
INFO     |  step #15: cost=117.5282 lambd=0.0640 inexact_tol=4.6e-04
INFO     |      - reprojection_cost(13): 117.52823 (avg 0.08371)
INFO     |  step #16: cost=117.5282 lambd=0.1280 inexact_tol=4.6e-04
INFO     |      - reprojection_cost(13): 117.52823 (avg 0.08371)
INFO     |      accepted=True ATb_norm=1.61e+03 cost_prev=117.5282 cost_new=117.4151
INFO     |  step #17: cost=117.4151 lambd=0.0640 inexact_tol=2.2e-04
INFO     |      - reprojection_cost(13): 117.41513 (avg 0.08363)
INFO     |  step #18: cost=117.4151 lambd=0.1280 inexact_tol=2.2e-04
INFO     |      - reprojection_cost(13): 117.41513 (avg 0.08363)
INFO     |  step #19: cost=117.4151 lambd=0.2560 inexact_tol=2.2e-04
INFO     |      - reprojection_cost(13): 117.41513 (avg 0.08363)
INFO     |  step #20: cost=117.4151 lambd=0.5120 inexact_tol=2.2e-04
INFO     |      - reprojection_cost(13): 117.41513 (avg 0.08363)
INFO     |      accepted=True ATb_norm=1.46e+01 cost_prev=117.4151 cost_new=117.4097
INFO     |  step #21: cost=117.4097 lambd=0.2560 inexact_tol=7.5e-05
INFO     |      - reprojection_cost(13): 117.40974 (avg 0.08363)
INFO     |  step #22: cost=117.4097 lambd=0.5120 inexact_tol=7.5e-05
INFO     |      - reprojection_cost(13): 117.40974 (avg 0.08363)
INFO     |      accepted=True ATb_norm=7.08e+00 cost_prev=117.4097 cost_new=117.4080
INFO     |  step #23: cost=117.4080 lambd=0.2560 inexact_tol=7.5e-05
INFO     |      - reprojection_cost(13): 117.40798 (avg 0.08362)
INFO     |  step #24: cost=117.4080 lambd=0.5120 inexact_tol=7.5e-05
INFO     |      - reprojection_cost(13): 117.40798 (avg 0.08362)
INFO     |      accepted=True ATb_norm=8.25e+00 cost_prev=117.4080 cost_new=117.4064
INFO     |  step #25: cost=117.4064 lambd=0.2560 inexact_tol=7.5e-05
INFO     |      - reprojection_cost(13): 117.40639 (avg 0.08362)
INFO     |  step #26: cost=117.4064 lambd=0.5120 inexact_tol=7.5e-05
INFO     |      - reprojection_cost(13): 117.40639 (avg 0.08362)
INFO     |      accepted=True ATb_norm=5.53e+00 cost_prev=117.4064 cost_new=117.4042
INFO     |  step #27: cost=117.4042 lambd=0.2560 inexact_tol=7.5e-05
INFO     |      - reprojection_cost(13): 117.40421 (avg 0.08362)
INFO     |  step #28: cost=117.4042 lambd=0.5120 inexact_tol=7.5e-05
INFO     |      - reprojection_cost(13): 117.40421 (avg 0.08362)
INFO     |      accepted=True ATb_norm=5.43e+00 cost_prev=117.4042 cost_new=117.4032
INFO     | Terminated @ iteration #29: cost=117.4032 criteria=[1 0 0], term_deltas=8.9e-06,2.9e+00,3.7e-05
# Tabulate the optimized intrinsics next to the values the solve started from.
est_intrinsics = solution[intrinsics_var]
param_names = ["fx", "fy", "cx", "cy", "k1", "k2", "p1", "p2"]

print("Estimated intrinsics:")
print(f"  {'Parameter':<12} {'Initial':>12} {'Estimated':>12}")
print(f"  {'-' * 38}")
# The names follow the same order as the entries of the intrinsics arrays.
for name, init, est in zip(param_names, init_intrinsics, est_intrinsics):
    print(f"  {name:<12} {float(init):>12.4f} {float(est):>12.4f}")
Estimated intrinsics:
  Parameter         Initial    Estimated
  --------------------------------------
  fx               320.0000     536.4322
  fy               320.0000     536.3876
  cx               320.0000     342.2786
  cy               240.0000     235.6965
  k1                 0.0000      -0.2786
  k2                 0.0000       0.0673
  p1                 0.0000       0.0018
  p2                 0.0000      -0.0003

Visualization

Hide code cell source

def compute_reprojection_errors(
    intrinsics: jax.Array, poses: list[jaxlie.SE3]
) -> tuple[list[jax.Array], list[jax.Array]]:
    """Project the board into every view and measure the per-corner pixel error.

    Uses the module-level ``board_points_3d`` and ``observations_2d``; the
    i-th pose is compared against the i-th set of observations.

    Args:
        intrinsics: Camera intrinsics [fx, fy, cx, cy, k1, k2, p1, p2]
        poses: List of camera poses (one per view)

    Returns:
        Tuple of (projected_points, errors) where each is a list per view
    """
    projected_per_view = []
    error_per_view = []
    for view, camera_pose in enumerate(poses):
        # Apply the pose to the board points, then project them into the image.
        reprojected = project_brown_conrady(
            jax.vmap(camera_pose.apply)(board_points_3d), intrinsics
        )
        # Euclidean pixel distance between each projection and its observation.
        residual = reprojected - observations_2d[view]
        projected_per_view.append(reprojected)
        error_per_view.append(jnp.linalg.norm(residual, axis=-1))
    return projected_per_view, error_per_view


# Reprojection errors at the PnP initialization and at the optimized solution.
init_projected, init_errors = compute_reprojection_errors(init_intrinsics, init_poses)
est_poses = [solution[pose_var] for pose_var in pose_vars]
est_projected, est_errors = compute_reprojection_errors(est_intrinsics, est_poses)

# RMSE pooled over every corner of every view.
init_sq_errors = jnp.concatenate([jnp.square(e) for e in init_errors])
est_sq_errors = jnp.concatenate([jnp.square(e) for e in est_errors])
init_rmse = float(jnp.sqrt(jnp.mean(init_sq_errors)))
est_rmse = float(jnp.sqrt(jnp.mean(est_sq_errors)))
print(
    f"Reprojection RMSE: {init_rmse:.3f} px (initial) -> {est_rmse:.3f} px (optimized)"
)
Reprojection RMSE: 4.096 px (initial) -> 0.409 px (optimized)

Hide code cell source

# Reprojection error distribution.
# Pool the per-view error arrays into one flat array per condition.
all_init_errors = jnp.concatenate(init_errors)
all_est_errors = jnp.concatenate(est_errors)

# Two semi-transparent histograms drawn on top of each other:
# initial errors in red ("tomato"), optimized errors in blue ("steelblue").
fig_errors = go.Figure()
fig_errors.add_trace(
    go.Histogram(
        x=all_init_errors,
        name="Initial",
        marker_color="tomato",
        opacity=0.7,
        nbinsx=30,
    )
)
fig_errors.add_trace(
    go.Histogram(
        x=all_est_errors,
        name="Optimized",
        marker_color="steelblue",
        opacity=0.7,
        nbinsx=30,
    )
)
fig_errors.update_layout(
    title="Reprojection Error Distribution",
    xaxis_title="Error (pixels)",
    yaxis_title="Count",
    barmode="overlay",
    height=300,
    margin=dict(t=40, b=40, l=60, r=40),
    legend=dict(yanchor="top", y=0.99, xanchor="right", x=0.99),
)
# Embed the figure as an HTML fragment, loading plotly.js from the CDN.
HTML(fig_errors.to_html(full_html=False, include_plotlyjs="cdn"))

Hide code cell source

# Camera poses (top-down view)
# The optimized poses are the ones applied to board points before projection,
# so the translation of each inverse pose is the camera position expressed in
# the board frame. Only its X and Y components are plotted.
cam_positions = [pose.inverse().translation() for pose in est_poses]
cam_x = [float(p[0]) for p in cam_positions]
cam_y = [float(p[1]) for p in cam_positions]

# Chessboard outline.
# Closed rectangle (first point repeated) in the Z=0 board plane.
# NOTE(review): the extent uses board_cols/board_rows (inner-corner counts)
# times the square size, which is one square larger than the corner grid.
board_corners = jnp.array(
    [
        [0, 0, 0],
        [board_cols * square_size, 0, 0],
        [board_cols * square_size, board_rows * square_size, 0],
        [0, board_rows * square_size, 0],
        [0, 0, 0],
    ]
)

fig_poses = go.Figure()
# Board outline; coordinates are multiplied by 1000 to match the mm axis labels.
fig_poses.add_trace(
    go.Scatter(
        x=board_corners[:, 0] * 1000,
        y=board_corners[:, 1] * 1000,
        mode="lines",
        line=dict(color="gray", width=2),
        name="Board",
    )
)
# Camera positions, labelled 1..num_views in view order.
fig_poses.add_trace(
    go.Scatter(
        x=[c * 1000 for c in cam_x],
        y=[c * 1000 for c in cam_y],
        mode="markers+text",
        marker=dict(size=10, color="steelblue"),
        text=[str(i + 1) for i in range(num_views)],
        textposition="top center",
        name="Cameras",
    )
)
fig_poses.update_layout(
    title="Camera Poses (top view)",
    xaxis_title="X (mm)",
    yaxis_title="Y (mm)",
    # Lock the aspect ratio so X and Y share the same scale.
    xaxis=dict(scaleanchor="y", scaleratio=1),
    height=350,
    margin=dict(t=40, b=40, l=60, r=40),
    showlegend=False,
)
# Embed the figure as an HTML fragment, loading plotly.js from the CDN.
HTML(fig_poses.to_html(full_html=False, include_plotlyjs="cdn"))

Hide code cell source

# Single view comparison: initial vs optimized reprojection.
# view_idx indexes the detected views (observations_2d and the per-view
# projection/error lists).
view_idx = 1
obs = observations_2d[view_idx]

# Left panel: initial parameters; right panel: optimized parameters.
# Each title carries that view's own RMSE.
fig_view = make_subplots(
    rows=1,
    cols=2,
    subplot_titles=(
        f"Initial (RMSE={float(jnp.sqrt(jnp.mean(init_errors[view_idx] ** 2))):.2f}px)",
        f"Optimized (RMSE={float(jnp.sqrt(jnp.mean(est_errors[view_idx] ** 2))):.2f}px)",
    ),
)

# Initial projection.
# Observed corners (green circles); only the left-panel traces get legend entries.
fig_view.add_trace(
    go.Scatter(
        x=obs[:, 0],
        y=obs[:, 1],
        mode="markers",
        marker=dict(size=8, color="green", symbol="circle"),
        name="Observed",
        showlegend=True,
    ),
    row=1,
    col=1,
)
# Corners projected with the initial intrinsics and PnP pose (red x markers).
fig_view.add_trace(
    go.Scatter(
        x=init_projected[view_idx][:, 0],
        y=init_projected[view_idx][:, 1],
        mode="markers",
        marker=dict(size=6, color="tomato", symbol="x"),
        name="Projected",
        showlegend=True,
    ),
    row=1,
    col=1,
)
# One thin line per corner joining the observation to its projection.
for j in range(len(obs)):
    fig_view.add_trace(
        go.Scatter(
            x=[obs[j, 0], init_projected[view_idx][j, 0]],
            y=[obs[j, 1], init_projected[view_idx][j, 1]],
            mode="lines",
            line=dict(color="tomato", width=0.5),
            showlegend=False,
            hoverinfo="skip",
        ),
        row=1,
        col=1,
    )

# Optimized projection.
# Same three layers as the left panel, without legend entries.
fig_view.add_trace(
    go.Scatter(
        x=obs[:, 0],
        y=obs[:, 1],
        mode="markers",
        marker=dict(size=8, color="green", symbol="circle"),
        showlegend=False,
    ),
    row=1,
    col=2,
)
# Corners projected with the optimized intrinsics and pose (blue x markers).
fig_view.add_trace(
    go.Scatter(
        x=est_projected[view_idx][:, 0],
        y=est_projected[view_idx][:, 1],
        mode="markers",
        marker=dict(size=6, color="steelblue", symbol="x"),
        showlegend=False,
    ),
    row=1,
    col=2,
)
for j in range(len(obs)):
    fig_view.add_trace(
        go.Scatter(
            x=[obs[j, 0], est_projected[view_idx][j, 0]],
            y=[obs[j, 1], est_projected[view_idx][j, 1]],
            mode="lines",
            line=dict(color="steelblue", width=0.5),
            showlegend=False,
            hoverinfo="skip",
        ),
        row=1,
        col=2,
    )

fig_view.update_xaxes(title_text="u (pixels)")
# Reverse the y axis so v increases downward, as in image coordinates.
fig_view.update_yaxes(title_text="v (pixels)", autorange="reversed")
fig_view.update_layout(
    height=400,
    margin=dict(t=40, b=80, l=60, r=40),
    # Horizontal legend centered below the plots.
    legend=dict(orientation="h", yanchor="top", y=-0.15, xanchor="center", x=0.5),
)
# Embed the figure as an HTML fragment, loading plotly.js from the CDN.
HTML(fig_view.to_html(full_html=False, include_plotlyjs="cdn"))

Undistortion

Apply the estimated distortion parameters to rectify the images:

Hide code cell source

def undistort_image(img: np.ndarray, intrinsics: jax.Array) -> np.ndarray:
    """Undistort an image using the estimated intrinsics and scipy.ndimage.map_coordinates.

    Every output pixel is treated as an undistorted location, pushed through
    the forward distortion model to find where it lies in the distorted input,
    and filled by bilinear sampling at that location. Output pixels whose
    source falls outside the input are set to 0.

    Args:
        img: Input image (H, W, C) with any number of channels, or (H, W)
        intrinsics: Camera intrinsics [fx, fy, cx, cy, k1, k2, p1, p2]

    Returns:
        Undistorted image with same shape as input
    """
    fx, fy, cx, cy, k1, k2, p1, p2 = [float(x) for x in intrinsics]
    h, w = img.shape[:2]

    # Create grid of output pixel coordinates (in undistorted image)
    u, v = np.meshgrid(np.arange(w), np.arange(h))

    # Convert to undistorted normalized coordinates.
    x_n = (u - cx) / fx
    y_n = (v - cy) / fy

    # Apply forward distortion to find where to sample from in the distorted input.
    r2 = x_n**2 + y_n**2
    radial = 1.0 + k1 * r2 + k2 * r2**2
    dx_t = 2 * p1 * x_n * y_n + p2 * (r2 + 2 * x_n**2)
    dy_t = p1 * (r2 + 2 * y_n**2) + 2 * p2 * x_n * y_n
    x_d = x_n * radial + dx_t
    y_d = y_n * radial + dy_t

    # Convert to pixel coordinates in the distorted input image.
    u_src = fx * x_d + cx
    v_src = fy * y_d + cy

    # map_coordinates expects coordinates in array-axis order: (row, column).
    sample_at = [v_src, u_src]

    # Sample from source image using map_coordinates.
    if img.ndim == 3:
        # Multi-channel image - resample each channel with the same map.
        # Loop over the actual channel count so e.g. RGBA is handled too.
        undistorted = np.zeros_like(img)
        for c in range(img.shape[2]):
            undistorted[:, :, c] = ndimage.map_coordinates(
                img[:, :, c], sample_at, order=1, mode="constant", cval=0
            )
    else:
        undistorted = ndimage.map_coordinates(
            img, sample_at, order=1, mode="constant", cval=0
        )

    return undistorted


def compute_distortion_at_points(
    intrinsics: jax.Array, points: jax.Array
) -> np.ndarray:
    """Compute distortion magnitude at specific pixel locations.

    Each pixel is treated as an undistorted location and pushed through the
    full forward distortion model (radial k1, k2 and tangential p1, p2), the
    same model ``compute_distortion_magnitude`` evaluates on the pixel grid;
    the result is the pixel distance between the input and its distorted
    position.

    Args:
        intrinsics: Camera intrinsics [fx, fy, cx, cy, k1, k2, p1, p2]
        points: Pixel coordinates (N, 2)

    Returns:
        Distortion magnitude at each point (N,)
    """
    fx, fy, cx, cy, k1, k2, p1, p2 = [float(x) for x in intrinsics]
    # Work in NumPy so the result really is an np.ndarray, as annotated.
    pixels = np.asarray(points)
    u, v = pixels[:, 0], pixels[:, 1]

    # Pixel -> normalized coordinates.
    x_n = (u - cx) / fx
    y_n = (v - cy) / fy

    # Radial and tangential distortion.
    r2 = x_n**2 + y_n**2
    radial = 1.0 + k1 * r2 + k2 * r2**2
    dx_t = 2 * p1 * x_n * y_n + p2 * (r2 + 2 * x_n**2)
    dy_t = p1 * (r2 + 2 * y_n**2) + 2 * p2 * x_n * y_n
    x_d = x_n * radial + dx_t
    y_d = y_n * radial + dy_t

    # Back to pixels.
    u_d = fx * x_d + cx
    v_d = fy * y_d + cy

    return np.sqrt((u_d - u) ** 2 + (v_d - v) ** 2)


def compute_distortion_magnitude(
    intrinsics: jax.Array, shape: tuple[int, int]
) -> np.ndarray:
    """Build a map of how far lens distortion displaces each pixel.

    Every pixel of an image with the given shape is treated as an undistorted
    location, run through the radial + tangential distortion model, and
    compared with where it started.

    Args:
        intrinsics: Camera intrinsics [fx, fy, cx, cy, k1, k2, p1, p2]
        shape: Image shape (height, width)

    Returns:
        Distortion magnitude map (H, W) in pixels
    """
    fx, fy, cx, cy, k1, k2, p1, p2 = (float(value) for value in intrinsics)
    n_rows, n_cols = shape

    # Integer pixel grid: px_v is the row index, px_u the column index.
    px_v, px_u = np.indices((n_rows, n_cols))

    # Pixel -> normalized (undistorted) coordinates.
    xn = (px_u - cx) / fx
    yn = (px_v - cy) / fy

    # Forward distortion: radial gain plus tangential offsets.
    rr = xn**2 + yn**2
    gain = 1.0 + k1 * rr + k2 * rr**2
    tangential_x = 2 * p1 * xn * yn + p2 * (rr + 2 * xn**2)
    tangential_y = p1 * (rr + 2 * yn**2) + 2 * p2 * xn * yn
    x_dist = xn * gain + tangential_x
    y_dist = yn * gain + tangential_y

    # Pixel displacement between the distorted and original locations.
    shift_u = (fx * x_dist + cx) - px_u
    shift_v = (fy * y_dist + cy) - px_v
    return np.sqrt(shift_u**2 + shift_v**2)


# Show original vs undistorted + distortion map.
# sample_idx indexes the detected views; valid_image_indices maps it back to
# the corresponding file in image_paths.
sample_idx = 1
img_orig = cv2.imread(str(image_paths[valid_image_indices[sample_idx]]))
img_orig_rgb = cv2.cvtColor(img_orig, cv2.COLOR_BGR2RGB)  # BGR -> RGB for plotting
img_undist = undistort_image(img_orig_rgb, est_intrinsics)
distortion_map = compute_distortion_magnitude(est_intrinsics, img_orig.shape[:2])

# Compute distortion at observation locations for this view.
obs_distortion = compute_distortion_at_points(
    est_intrinsics, observations_2d[sample_idx]
)

# Three panels: input image, rectified image, distortion heatmap.
# The third title reports the min-max distortion over the detected corners.
fig_undist = make_subplots(
    rows=1,
    cols=3,
    subplot_titles=(
        "Original",
        "Undistorted",
        f"Distortion map (corners: {obs_distortion.min():.1f}-{obs_distortion.max():.1f}px)",
    ),
)

fig_undist.add_trace(go.Image(z=img_orig_rgb), row=1, col=1)
fig_undist.add_trace(go.Image(z=img_undist), row=1, col=2)
fig_undist.add_trace(
    go.Heatmap(
        z=distortion_map,
        colorscale="Hot",
        showscale=True,
        colorbar=dict(title="px", len=0.8, x=1.02),
    ),
    row=1,
    col=3,
)
# Overlay observation locations on distortion map.
# Hovering a corner shows the distortion value computed for it.
fig_undist.add_trace(
    go.Scatter(
        x=observations_2d[sample_idx][:, 0],
        y=observations_2d[sample_idx][:, 1],
        mode="markers",
        marker=dict(size=4, color="cyan", symbol="circle"),
        showlegend=False,
        hovertemplate="Distortion: %{text:.1f}px<extra></extra>",
        text=obs_distortion,
    ),
    row=1,
    col=3,
)

fig_undist.update_xaxes(showticklabels=False)
# Reverse the heatmap's y axis so row 0 is at the top, as in the image panels.
fig_undist.update_yaxes(showticklabels=False, autorange="reversed", row=1, col=3)
fig_undist.update_layout(
    height=280,
    margin=dict(t=40, b=20, l=20, r=40),
)
# Embed the figure as an HTML fragment, loading plotly.js from the CDN.
HTML(fig_undist.to_html(full_html=False, include_plotlyjs="cdn"))

The optimization calibrated the camera from checkerboard images:

  • Error histogram: Reprojection error distribution before (red) and after (blue) optimization

  • Camera pose plot: Top-down view of estimated camera positions relative to the chessboard

  • Single-view comparison: Observed corners (green) vs projected corners (x markers), before and after optimization

For more details, see: